from collections import defaultdict, OrderedDict
from dataclasses import dataclass
from typing import Union

import numpy as np

try:
    from openvino import Model, Node
    from openvino.op import Parameter, Constant
    from openvino.utils.types import get_element_type
    import openvino.opset12 as opset
except ImportError:
    from openvino.runtime import Model, Node
    from openvino.runtime.op import Parameter, Constant
    import openvino.runtime.opset12 as opset
    from openvino.runtime.utils.types import get_element_type

import openvino as ov
from tqdm.auto import tqdm

OPERATION_TYPE_MAP = {"MatMul": opset.matmul, "Convolution": opset.convolution}

ORIGINAL_PRECISION_RT_INFO_NAME = "precise_0"


@dataclass
class TrackedNodeInfo:
    """
    Data associated with a node tracked for upcasting.

    Only `node` is set at construction time; the remaining fields are filled in
    during calibration (output insertion, full-precision inference, half-precision inference).
    """

    node: Node  # Target node to track
    snr: Union[float, None] = None  # SNR of the target node
    input_nodes: Union[list[Node], None] = None  # Input nodes of the target node
    result_node: Union[Node, None] = None  # Result node of the target node
    input_result_nodes: Union[dict[Node, Node], None] = None  # Result nodes of non-const inputs of the target node
    node_value_full_precision: Union[np.ndarray, None] = None  # Result of the node in full precision
    node_value_half_precision: Union[np.ndarray, None] = None  # Result of the node in half precision
    input_values_full_precision: Union[list[np.ndarray], None] = None  # Results of the target node inputs in full precision


def partially_upcast_nodes_to_fp32(
    orig_model: Model,
    example_input: Union[list, dict],
    half_type: str = "f16",
    batch_size: int = 50,
    operation_types: list[str] = None,
    upcast_ratio: float = 0.1,
    verbose: bool = False,
) -> Model:
    """
    Transform a model to upcast some nodes to be executed in full precision instead of half precision. These nodes are
    marked with runtime info flag.
    Nodes are selected based on Signal-to-Noise Ratio (SNR) metric: upcast_ratio fraction of tracked nodes with the
    lowest SNR are marked for full precision execution.

    Note: Input model should have fp16 weights (i.e. saved with compress_to_fp16=True) in order to conserve
    calibration memory.

    :param orig_model: Model to process
    :param example_input: Example input for model inference
    :param half_type: Either "f16" or "bf16"
    :param batch_size: Number of nodes to process together during a single model inference. The lower the value is,
        the less memory footprint is, but the larger is the processing time. The value of -1 is used to disable
        batching.
    :param operation_types: Types of operations to consider. If None, MatMuls and Convolutions are considered.
    :param upcast_ratio: Fraction of nodes to upcast (with the lowest SNR). 0 - do not upcast anything, 1 - upcast every
        operation of the given types.
    :param verbose: If True, prints progress output.
    :return: Upcasted OV model with some nodes marked for full precision execution.
    """
    if half_type not in ("f16", "bf16"):
        raise ValueError(f"Half type must be either 'f16' or 'bf16'. Got {half_type}.")
    if half_type == "bf16":
        print("Warning! Calibration currently does not provide any improvement for bf16 type.")
    if operation_types is None:
        operation_types = ["MatMul", "Convolution"]
    for op_type in operation_types:
        if op_type not in OPERATION_TYPE_MAP:
            raise ValueError(f"Operation type must be one of the following {list(OPERATION_TYPE_MAP.keys())}. " f"Got {op_type}.")
    if verbose:
        print(f"The following operation types will be considered: {operation_types}")

    # bf16 is only natively executed on CPU; f16 requires GPU
    device = "GPU" if half_type == "f16" else "CPU"

    nodes_to_track_names = get_nodes_to_track(orig_model, operation_types)
    if len(nodes_to_track_names) == 0:
        if verbose:
            print("Warning. Not found any operations of the given type(s).")
        return orig_model.clone()

    # The degenerate ratios do not need calibration at all: skip the (expensive) SNR
    # collection loop entirely instead of iterating over batches doing nothing.
    if upcast_ratio == 0.0:
        if verbose:
            print("Skipping algorithm because upcast ratio equals 0.0. Nothing to upcast.")
        node_to_upcast_names = []
    elif upcast_ratio == 1.0:
        if verbose:
            print("Skipping algorithm because upcast ratio equals 1.0. Upcasting all nodes of the given type(s).")
        node_to_upcast_names = nodes_to_track_names
    else:
        node_names_and_snrs = _collect_node_snrs(orig_model, nodes_to_track_names, example_input, device, half_type, batch_size, verbose)

        node_names_and_snrs = sorted(node_names_and_snrs, key=lambda it: it[1])
        node_names, node_snrs = tuple(zip(*node_names_and_snrs))

        n_nodes = len(node_names)
        # Upcast the `upcast_ratio` fraction of nodes with the lowest SNR (at least one node)
        nodes_to_upcast_cnt = int(np.ceil(n_nodes * upcast_ratio))
        node_to_upcast_names = node_names[:nodes_to_upcast_cnt]

        if verbose:
            snr_quantile = node_snrs[nodes_to_upcast_cnt - 1]
            print(f"Upcasted {nodes_to_upcast_cnt}/{n_nodes} nodes with SNR less than {snr_quantile:.2f}.")
            for node_name, node_snr in node_names_and_snrs[:nodes_to_upcast_cnt]:
                print(node_name, node_snr)

    new_model = orig_model.clone()
    mark_nodes_to_upcast_to_fp32(new_model, node_to_upcast_names)
    return new_model


def _collect_node_snrs(
    orig_model: Model,
    nodes_to_track_names: list[str],
    example_input: Union[list, dict],
    device: str,
    half_type: str,
    batch_size: int,
    verbose: bool,
) -> list[tuple]:
    """
    Calibrate the tracked nodes batch by batch and return (friendly name, SNR) pairs.

    For each batch: clone the model, expose tracked nodes (and their non-constant inputs) as
    outputs, infer in full precision, then re-infer each node standalone in half precision and
    compute the SNR between the two results.
    """
    batch_size = len(nodes_to_track_names) if batch_size == -1 or batch_size > len(nodes_to_track_names) else batch_size
    if verbose:
        print("Started upcasting")
    node_names_and_snrs = []
    for i in tqdm(
        range(0, len(nodes_to_track_names), batch_size),
        desc="Processing batches",
        disable=not verbose,
    ):
        model = orig_model.clone()
        name_to_node_map = {op.get_friendly_name(): op for op in model.get_ops()}
        nodes_to_track_batch = [TrackedNodeInfo(name_to_node_map[node_name]) for node_name in nodes_to_track_names[i : i + batch_size]]

        # Add outputs for non-constant inputs of tracked nodes
        insert_outputs_for_tracked_ops(model, nodes_to_track_batch)
        # Infer model to collect tracked operation results and results of their inputs in full precision
        infer_full_net(nodes_to_track_batch, model, example_input)
        # Infer nodes in half precision one by one using full precision inputs, collect half precision results
        infer_nodes(nodes_to_track_batch, device, half_type)

        # Compute operation SNR based on full precision and half precision results
        for node_info in nodes_to_track_batch:
            try:
                snr = compute_snr(
                    node_info.node_value_full_precision,
                    node_info.node_value_half_precision,
                )
            except RuntimeError as e:
                # TODO: find the reason behind this
                if node_info.node.get_type_name() in [
                    "Add",
                    "Concat",
                ] and "Shape mismatch" in str(e):
                    print(
                        "Warning.",
                        str(e),
                        node_info.node.get_friendly_name(),
                        node_info.node.get_type_name(),
                        [(inp_node.get_friendly_name(), inp_node.get_type_name()) for inp_node in node_info.input_nodes],
                    )
                    # Treat the mismatching node as "perfect" so it is never selected for upcasting
                    snr = np.finfo(np.float32).max
                else:
                    raise e
            node_names_and_snrs.append((node_info.node.get_friendly_name(), snr))
    return node_names_and_snrs


def get_nodes_to_track(model: Model, operation_types: list[str]) -> list:
    """
    Collect friendly names of operations to calibrate.

    An operation is tracked when its type is one of `operation_types` and none of the
    consumers of its first output is a Result node (final outputs are skipped).

    :param model: Model to scan
    :param operation_types: Operation type names to consider (e.g. "MatMul")
    :return: List of friendly names in topological order
    """
    return [
        op.get_friendly_name()
        for op in model.get_ordered_ops()
        if op.get_type_name() in operation_types
        and all(target.get_node().get_type_name() != "Result" for target in op.output(0).get_target_inputs())
    ]


def insert_outputs_for_tracked_ops(model: Model, nodes_to_track: list[TrackedNodeInfo]) -> None:
    """
    Add model outputs for every tracked node and for each of its non-constant inputs.

    Fills `input_nodes`, `result_node` and `input_result_nodes` on each TrackedNodeInfo so that
    inference results can later be matched back to the tracked nodes via the friendly names of
    the inserted Result nodes.

    :param model: Model to modify in place (outputs are added via `model.add_outputs`)
    :param nodes_to_track: Tracked node infos; updated in place
    """
    # node -> output to expose as a Result (ordered, so add_outputs results can be zipped back)
    node_to_output_map = OrderedDict()
    # node -> list of (TrackedNodeInfo, "parent"/"child") entries referencing it
    node_to_node_info_map = defaultdict(list)
    for node_info in nodes_to_track:
        node = node_info.node
        node_to_node_info_map[node].append((node_info, "parent"))  # add as a parent node
        if node not in node_to_output_map:
            node_to_output_map[node] = node.output(0)
        node_info.input_nodes = []
        for inp_value in node.input_values():
            child_node = inp_value.get_node()
            node_info.input_nodes.append(child_node)
            # Do not add outputs for constant nodes
            if child_node.get_type_name() != "Constant" and not is_constant_path(child_node):
                node_to_node_info_map[child_node].append((node_info, "child"))  # add as a child node
                if child_node not in node_to_output_map:
                    node_to_output_map[child_node] = child_node.output(0)

    # NOTE(review): relies on add_outputs returning outputs in the same order as the values
    # passed in, so zipping pairs each inserted Result with its source node — confirm with OV docs
    outputs = model.add_outputs(list(node_to_output_map.values()))
    for output, node in zip(outputs, node_to_output_map.keys()):
        # Value matching will be done later based on result node friendly names
        result_node = output.node
        for node_info, parent_label in node_to_node_info_map[node]:
            is_parent = parent_label == "parent"
            if is_parent:
                # The Result exposes the tracked node's own output
                node_info.result_node = result_node
            else:
                # The Result exposes one of the tracked node's inputs
                if node_info.input_result_nodes is None:
                    node_info.input_result_nodes = {}
                node_info.input_result_nodes[node] = result_node


def get_const_value_from_ovmodel(node: Union[Constant, Node]) -> np.ndarray:
    """
    Extract a constant value from an ov.Model node.

    Accepts either a Constant node directly, or a decompression Convert whose input is a
    Constant (the "constant path" case for weight-compressed models). In the latter case the
    raw (e.g. f16) weight is returned without applying the Convert.

    :param node: Constant node or decompression Convert node
    :return: Constant value as a numpy array
    :raises ValueError: If the node is neither a Constant nor on a constant path
    """
    if node.get_type_name() == "Constant":
        # Direct constants are expected not to be half precision; compressed weights should
        # flow through a decompression Convert instead
        assert node.get_element_type() not in [
            ov.Type.f16,
            ov.Type.bf16,
        ], f"{node.get_friendly_name()}, {node.get_element_type()}"
        return node.get_data()
    if is_constant_path(node):
        # If model is compressed and constant values flow through decompression convert
        const_node = node.input_value(0).get_node()
        assert const_node.get_type_name() == "Constant"
        assert const_node.get_element_type().is_real(), const_node.get_element_type()
        return const_node.get_data()  # return f16 weight
    raise ValueError(f"Cannot get const values from ov.Model for {node.get_friendly_name()} with type {node.get_type_name()}")


def is_constant_path(node: Node) -> bool:
    """
    Check whether `node` is a decompression Convert on a constant path.

    True when the node is a Convert that either carries a non-empty "is_decompression_0"
    rt_info flag or takes its input directly from a Constant node.

    :param node: Node to check
    :return: True if the node is a constant-path Convert
    """
    if node.get_type_name() != "Convert":
        return False
    rt_info = node.get_rt_info()
    # Guard the lookup: not every Convert necessarily carries the decompression flag
    if "is_decompression_0" in rt_info.keys() and len(rt_info["is_decompression_0"].aslist()) > 0:
        return True
    return node.input_value(0).get_node().get_type_name() == "Constant"


def infer_full_net(nodes_to_track: list[TrackedNodeInfo], orig_model: Model, example_inputs: list) -> None:
    """
    Infer the output-augmented model in full precision on CPU and attach results to tracked nodes.

    Fills `node_value_full_precision` and `input_values_full_precision` on every tracked node.
    Constant inputs are read directly from the model; non-constant inputs are taken from the
    inference results via the Result nodes inserted by `insert_outputs_for_tracked_ops`.

    :param nodes_to_track: Tracked node infos with result nodes already assigned; updated in place
    :param orig_model: Model with extra outputs for the tracked nodes
    :param example_inputs: Inputs to run inference with
    """
    core = ov.Core()
    exec_net = core.compile_model(orig_model, "CPU", config={"INFERENCE_PRECISION_HINT": "f32"})
    request = exec_net.create_infer_request()
    results = request.infer(example_inputs, share_inputs=True, share_outputs=True)

    # Map Result-node friendly names to inferred values for the lookups below
    friendly_name_to_result_map = {key.node.get_friendly_name(): val for key, val in results.items()}

    for node_info in nodes_to_track:
        node_info.node_value_full_precision = friendly_name_to_result_map[node_info.result_node.get_friendly_name()]
        node_info.input_values_full_precision = []
        for input_node in node_info.input_nodes:
            if input_node.get_type_name() == "Constant" or is_constant_path(input_node):
                # If input is constant, retrieve its value from model
                input_value = get_const_value_from_ovmodel(input_node)
            else:
                # If input is not constant, retrieve its input from inference results
                input_value = friendly_name_to_result_map[node_info.input_result_nodes[input_node].get_friendly_name()]
            node_info.input_values_full_precision.append(input_value)


def infer_nodes(nodes_to_track: list[TrackedNodeInfo], device: str, precision: str) -> None:
    """Run standalone half-precision inference for every tracked node, storing results in place."""
    for tracked_node in nodes_to_track:
        infer_tracked_op(tracked_node, device, precision)


def infer_tracked_op(node_info: TrackedNodeInfo, device: str, precision: str) -> None:
    """
    Rebuild the tracked operation as a standalone single-op model and infer it in half precision.

    Inputs are fed from the previously collected full-precision values
    (`node_info.input_values_full_precision`). f16 inputs get an explicit Convert to f32 so the
    single-op model is built in full precision while the INFERENCE_PRECISION_HINT controls the
    actual execution precision. The result is stored in `node_info.node_value_half_precision`.

    :param node_info: Tracked node with full-precision input values already collected
    :param device: Device to run inference on ("CPU" or "GPU")
    :param precision: Inference precision hint ("f16" or "bf16")
    """
    parameters = []
    inputs = []
    input_values = node_info.input_values_full_precision
    for input_value in input_values:
        parameter = Parameter(get_element_type(input_value.dtype), ov.PartialShape(input_value.shape))
        if input_value.dtype == np.float16:
            # Convert f16 weight to f32
            inputs.append(opset.convert(parameter, "f32"))
        else:
            inputs.append(parameter)
        parameters.append(parameter)

    node = node_info.node
    try:
        call_attributes = node.get_attributes()
        # Below are some op workarounds: attribute names returned by get_attributes() do not
        # always match the opset factory-function keyword names
        if node.get_type_name() == "Divide" and "m_pythondiv" in call_attributes:
            del call_attributes["m_pythondiv"]
        if node.get_type_name() == "Broadcast" and "mode" in call_attributes:
            call_attributes["broadcast_spec"] = call_attributes.pop("mode")
        if node.get_type_name() == "Concat":
            # Concat takes its inputs as a single list argument
            new_op = OPERATION_TYPE_MAP[node.get_type_name()](inputs, **call_attributes)
        else:
            new_op = OPERATION_TYPE_MAP[node.get_type_name()](*inputs, **call_attributes)

        ov_model = ov.Model([new_op], parameters=parameters)
        exec_net = ov.Core().compile_model(ov_model, device, config={"INFERENCE_PRECISION_HINT": precision})
        request = exec_net.create_infer_request()
        result = request.infer(input_values, share_inputs=True, share_outputs=True)
    except Exception as e:
        print(
            "Operation inference error",
            node.get_type_name(),
            node.get_friendly_name(),
            inputs,
            node.get_attributes(),
        )
        raise e

    # Validate before storing: the single-op model must produce exactly one output
    assert len(result) == 1
    node_info.node_value_half_precision = result[0]


def is_model_partially_upcasted(model) -> bool:
    """Return True if any considered operation in the model already carries the upcast rt_info flag."""
    return any(
        ORIGINAL_PRECISION_RT_INFO_NAME in op.get_rt_info().keys()
        for op in model.get_ordered_ops()
        if op.get_type_name() in OPERATION_TYPE_MAP.keys()
    )


def mark_nodes_to_upcast_to_fp32(model: ov.Model, nodes_with_errors: list[str]) -> None:
    """
    Set the full-precision rt_info flag on every model node whose friendly name is listed.

    Asserts that every requested name was found in the model.
    """
    remaining = set(nodes_with_errors)
    for op in model.get_ordered_ops():
        name = op.get_friendly_name()
        if name in remaining:
            op.get_rt_info()[ORIGINAL_PRECISION_RT_INFO_NAME] = ""
            remaining.discard(name)
    assert len(remaining) == 0, remaining


def compute_snr(x, y):
    """
    Compute the Signal-to-Noise Ratio in dB between a reference and a noisy value.

    `x` is the original (full precision) value, `y` the noisy (half precision) one.
    NaNs and positive infinities are clamped to float32 max; a zero-signal/zero-noise pair
    yields float32 max (perfect match).

    :raises RuntimeError: If the two arrays have a different number of elements.
    """
    largest = np.finfo(np.float32).max

    signal = x.astype(np.float32)
    noisy = y.astype(np.float32)

    if np.prod(signal.shape) != np.prod(noisy.shape):
        raise RuntimeError(f"Shape mismatch: {signal.shape}, {noisy.shape}.")

    signal = np.nan_to_num(signal, posinf=largest)
    noisy = np.nan_to_num(noisy, posinf=largest)

    signal_power = np.linalg.norm(signal)
    noise_power = np.nan_to_num(np.linalg.norm(signal - noisy), posinf=largest)

    if signal_power == 0.0 and noise_power == 0.0:
        return largest

    return np.nan_to_num(20 * np.log10(signal_power / noise_power), posinf=largest)
